# mssql.py

"""MSSQL backend, thru either pymssq, adodbapi or pyodbc interfaces.

* ``IDENTITY`` columns are supported by using SA ``schema.Sequence()``
  objects. In other words::

    Table('test', mss_engine,
           Column('id',   Integer, Sequence('blah',100,10), primary_key=True),
           Column('name', String(20))
         ).create()

  would yield::

   CREATE TABLE test (
     id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
     name VARCHAR(20)
     )

  Note that the start & increment values for sequences are optional
  and will default to 1,1.

* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for
  ``INSERT`` s)

* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``

* ``select._limit`` implemented as ``SELECT TOP n``

* Experimental implemention of LIMIT / OFFSET with row_number()

Known issues / TODO:

* No support for more than one ``IDENTITY`` column per table

* pymssql has problems with binary and unicode data that this module
  does **not** work around

"""
import datetime, operator, re, sys

from sqlalchemy import sql, schema, exc, util
from sqlalchemy.sql import compiler, expression, operators as sqlops, functions as sql_functions
from sqlalchemy.sql import compiler, expression, operators as sql_operators, functions as sql_functions
from sqlalchemy.engine import default, base
from sqlalchemy import types as sqltypes
from decimal import Decimal as _python_Decimal


MSSQL_RESERVED_WORDS = set(['function'])

class MSNumeric(sqltypes.Numeric):
    def result_processor(self, dialect):
        if self.asdecimal:
            def process(value):
                if value is not None:
                    return _python_Decimal(str(value))
                else:
                    return value
            return process
        else:
            def process(value):
                return float(value)
            return process

    def bind_processor(self, dialect):
        def process(value):
            if value is None:
                # Not sure that this exception is needed
                return value
            else:
                if not isinstance(value, float) and value._exp < -6:
                    value = ((value < 0 and '-' or '')
                        + '0.'
                        + '0' * -(value._exp+1)
                        + value._int)
                    return value
                else:
                    return str(value)
        return process

    def get_col_spec(self):
        if self.precision is None:
            return "NUMERIC"
        else:
            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}

class MSFloat(sqltypes.Float):
    def get_col_spec(self):
        return "FLOAT(%(precision)s)" % {'precision': self.precision}

    def bind_processor(self, dialect):
        def process(value):
            """By converting to string, we can use Decimal types round-trip."""
            if not value is None:
                return str(value)
            return None
        return process

class MSInteger(sqltypes.Integer):
    def get_col_spec(self):
        return "INTEGER"

class MSBigInteger(MSInteger):
    def get_col_spec(self):
        return "BIGINT"

class MSTinyInteger(MSInteger):
    def get_col_spec(self):
        return "TINYINT"

class MSSmallInteger(MSInteger):
    def get_col_spec(self):
        return "SMALLINT"

class MSDateTime(sqltypes.DateTime):
    def __init__(self, *a, **kw):
        super(MSDateTime, self).__init__(False)

    def get_col_spec(self):
        return "DATETIME"

class MSSmallDate(sqltypes.Date):
    def __init__(self, *a, **kw):
        super(MSSmallDate, self).__init__(False)

    def get_col_spec(self):
        return "SMALLDATETIME"

    def result_processor(self, dialect):
        def process(value):
            # If the DBAPI returns the value as datetime.datetime(), truncate it back to datetime.date()
            if type(value) is datetime.datetime:
                return value.date()
            return value
        return process

class MSDate(sqltypes.Date):
    def __init__(self, *a, **kw):
        super(MSDate, self).__init__(False)

    def get_col_spec(self):
        return "DATETIME"

    def result_processor(self, dialect):
        def process(value):
            # If the DBAPI returns the value as datetime.datetime(), truncate it back to datetime.date()
            if type(value) is datetime.datetime:
                return value.date()
            return value
        return process

class MSTime(sqltypes.Time):
    __zero_date = datetime.date(1900, 1, 1)

    def __init__(self, *a, **kw):
        super(MSTime, self).__init__(False)

    def get_col_spec(self):
        return "DATETIME"

    def bind_processor(self, dialect):
        def process(value):
            if type(value) is datetime.datetime:
                value = datetime.datetime.combine(self.__zero_date, value.time())
            elif type(value) is datetime.time:
                value = datetime.datetime.combine(self.__zero_date, value)
            return value
        return process

    def result_processor(self, dialect):
        def process(value):
            if type(value) is datetime.datetime:
                return value.time()
            elif type(value) is datetime.date:
                return datetime.time(0, 0, 0)
            return value
        return process

class MSDateTime_adodbapi(MSDateTime):
    def result_processor(self, dialect):
        def process(value):
            # adodbapi will return datetimes with empty time values as datetime.date() objects.
            # Promote them back to full datetime.datetime()
            if type(value) is datetime.date:
                return datetime.datetime(value.year, value.month, value.day)
            return value
        return process

class MSDateTime_pyodbc(MSDateTime):
    def bind_processor(self, dialect):
        def process(value):
            if type(value) is datetime.date:
                return datetime.datetime(value.year, value.month, value.day)
            return value
        return process

class MSDate_pyodbc(MSDate):
    def bind_processor(self, dialect):
        def process(value):
            if type(value) is datetime.date:
                return datetime.datetime(value.year, value.month, value.day)
            return value
        return process

class MSText(sqltypes.Text):
    def get_col_spec(self):
        if self.dialect.text_as_varchar:
            return "VARCHAR(max)"
        else:
            return "TEXT"

class MSString(sqltypes.String):
    def get_col_spec(self):
        return "VARCHAR" + (self.length and "(%d)" % self.length or "")

class MSNVarchar(sqltypes.Unicode):
    def get_col_spec(self):
        if self.length:
            return "NVARCHAR(%(length)s)" % {'length' : self.length}
        elif self.dialect.text_as_varchar:
            return "NVARCHAR(max)"
        else:
            return "NTEXT"

class AdoMSNVarchar(MSNVarchar):
    """overrides bindparam/result processing to not convert any unicode strings"""
    def bind_processor(self, dialect):
        return None

    def result_processor(self, dialect):
        return None

class MSChar(sqltypes.CHAR):
    def get_col_spec(self):
        return "CHAR(%(length)s)" % {'length' : self.length}

class MSNChar(sqltypes.NCHAR):
    def get_col_spec(self):
        return "NCHAR(%(length)s)" % {'length' : self.length}

class MSBinary(sqltypes.Binary):
    def get_col_spec(self):
        return "IMAGE"

class MSBoolean(sqltypes.Boolean):
    def get_col_spec(self):
        return "BIT"

    def result_processor(self, dialect):
        def process(value):
            if value is None:
                return None
            return value and True or False
        return process

    def bind_processor(self, dialect):
        def process(value):
            if value is True:
                return 1
            elif value is False:
                return 0
            elif value is None:
                return None
            else:
                return value and True or False
        return process

class MSTimeStamp(sqltypes.TIMESTAMP):
    def get_col_spec(self):
        return "TIMESTAMP"

class MSMoney(sqltypes.TypeEngine):
    def get_col_spec(self):
        return "MONEY"

class MSSmallMoney(MSMoney):
    def get_col_spec(self):
        return "SMALLMONEY"

class MSUniqueIdentifier(sqltypes.TypeEngine):
    def get_col_spec(self):
        return "UNIQUEIDENTIFIER"

class MSVariant(sqltypes.TypeEngine):
    def get_col_spec(self):
        return "SQL_VARIANT"

class MSSQLExecutionContext(default.DefaultExecutionContext):
    def __init__(self, *args, **kwargs):
        self.IINSERT = self.HASIDENT = False
        super(MSSQLExecutionContext, self).__init__(*args, **kwargs)

    def _has_implicit_sequence(self, column):
        if column.primary_key and column.autoincrement:
            if isinstance(column.type, sqltypes.Integer) and not column.foreign_keys:
                if column.default is None or (isinstance(column.default, schema.Sequence) and \
                                              column.default.optional):
                    return True
        return False

    def pre_exec(self):
        """MS-SQL has a special mode for inserting non-NULL values
        into IDENTITY columns.

        Activate it if the feature is turned on and needed.
        """
        if self.compiled.isinsert:
            tbl = self.compiled.statement.table
            if not hasattr(tbl, 'has_sequence'):
                tbl.has_sequence = None
                for column in tbl.c:
                    if getattr(column, 'sequence', False) or self._has_implicit_sequence(column):
                        tbl.has_sequence = column
                        break

            self.HASIDENT = bool(tbl.has_sequence)
            if self.dialect.auto_identity_insert and self.HASIDENT:
                if isinstance(self.compiled_parameters, list):
                    self.IINSERT = tbl.has_sequence.key in self.compiled_parameters[0]
                else:
                    self.IINSERT = tbl.has_sequence.key in self.compiled_parameters
            else:
                self.IINSERT = False

            if self.IINSERT:
                self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))

        super(MSSQLExecutionContext, self).pre_exec()

    def post_exec(self):
        """Turn off the INDENTITY_INSERT mode if it's been activated,
        and fetch recently inserted IDENTIFY values (works only for
        one column).
        """

        if self.compiled.isinsert and (not self.executemany) and self.HASIDENT and not self.IINSERT:
            if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
                if self.dialect.use_scope_identity:
                    self.cursor.execute("SELECT scope_identity() AS lastrowid")
                else:
                    self.cursor.execute("SELECT @@identity AS lastrowid")
                row = self.cursor.fetchone()
                self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
        super(MSSQLExecutionContext, self).post_exec()


class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
    def pre_exec(self):
        """where appropriate, issue "select scope_identity()" in the same statement"""
        super(MSSQLExecutionContext_pyodbc, self).pre_exec()
        if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) \
                and len(self.parameters) == 1 and self.dialect.use_scope_identity:
            self.statement += "; select scope_identity()"

    def post_exec(self):
        if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity:
            # do nothing - id was fetched in dialect.do_execute()
            pass
        else:
            super(MSSQLExecutionContext_pyodbc, self).post_exec()

class MSSQLDialect(default.DefaultDialect):
    name = 'mssql'
    supports_default_values = True
    supports_empty_insert = False

    colspecs = {
        sqltypes.Unicode : MSNVarchar,
        sqltypes.Integer : MSInteger,
        sqltypes.Smallinteger: MSSmallInteger,
        sqltypes.Numeric : MSNumeric,
        sqltypes.Float : MSFloat,
        sqltypes.DateTime : MSDateTime,
        sqltypes.Date : MSDate,
        sqltypes.Time : MSTime,
        sqltypes.String : MSString,
        sqltypes.Binary : MSBinary,
        sqltypes.Boolean : MSBoolean,
        sqltypes.Text : MSText,
        sqltypes.CHAR: MSChar,
        sqltypes.NCHAR: MSNChar,
        sqltypes.TIMESTAMP: MSTimeStamp,
    }

    ischema_names = {
        'int' : MSInteger,
        'bigint': MSBigInteger,
        'smallint' : MSSmallInteger,
        'tinyint' : MSTinyInteger,
        'varchar' : MSString,
        'nvarchar' : MSNVarchar,
        'char' : MSChar,
        'nchar' : MSNChar,
        'text' : MSText,
        'ntext' : MSText,
        'decimal' : MSNumeric,
        'numeric' : MSNumeric,
        'float' : MSFloat,
        'datetime' : MSDateTime,
        'date': MSDate,
        'smalldatetime' : MSSmallDate,
        'binary' : MSBinary,
        'varbinary' : MSBinary,
        'bit': MSBoolean,
        'real' : MSFloat,
        'image' : MSBinary,
        'timestamp': MSTimeStamp,
        'money': MSMoney,
        'smallmoney': MSSmallMoney,
        'uniqueidentifier': MSUniqueIdentifier,
        'sql_variant': MSVariant,
    }

    def __new__(cls, dbapi=None, *args, **kwargs):
        if cls != MSSQLDialect:
            return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs)
        if dbapi:
            dialect = dialect_mapping.get(dbapi.__name__)
            return dialect(*args, **kwargs)
        else:
            return object.__new__(cls, *args, **kwargs)

    def __init__(self, auto_identity_insert=True, **params):
        super(MSSQLDialect, self).__init__(**params)
        self.auto_identity_insert = auto_identity_insert
        self.text_as_varchar = False
        self.use_scope_identity = False
        self.has_window_funcs = False
        self.set_default_schema_name("dbo")

    def dbapi(cls, module_name=None):
        if module_name:
            try:
                dialect_cls = dialect_mapping[module_name]
                return dialect_cls.import_dbapi()
            except KeyError:
                raise exc.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
        else:
            for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]:
                try:
                    return dialect_cls.import_dbapi()
                except ImportError, e:
                    pass
            else:
                raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi')
    dbapi = classmethod(dbapi)

    def server_version_info(self, connection):
        """A tuple of the database server version.

        Formats the remote server version as a tuple of version values,
        e.g. ``(9, 0, 1399)``.  If there are strings in the version number
        they will be in the tuple too, so don't count on these all being
        ``int`` values.

        This is a fast check that does not require a round trip.  It is also
        cached per-Connection.
        """
        return connection.dialect._server_version_info(connection.connection)
    server_version_info = base.connection_memoize(
        ('mssql', 'server_version_info'))(server_version_info)

    def _server_version_info(self, dbapi_con):
        """Return a tuple of the database's version number."""

        raise NotImplementedError()
    
    def create_connect_args(self, url):
        opts = url.translate_connect_args(username='user')
        opts.update(url.query)
        if 'auto_identity_insert' in opts:
            self.auto_identity_insert = bool(int(opts.pop('auto_identity_insert')))
        if 'query_timeout' in opts:
            self.query_timeout = int(opts.pop('query_timeout'))
        if 'text_as_varchar' in opts:
            self.text_as_varchar = bool(int(opts.pop('text_as_varchar')))
        if 'use_scope_identity' in opts:
            self.use_scope_identity = bool(int(opts.pop('use_scope_identity')))
        if 'has_window_funcs' in opts:
            self.has_window_funcs =  bool(int(opts.pop('has_window_funcs')))
        return self.make_connect_string(opts)

    def create_execution_context(self, *args, **kwargs):
        return MSSQLExecutionContext(self, *args, **kwargs)

    def type_descriptor(self, typeobj):
        newobj = sqltypes.adapt_type(typeobj, self.colspecs)
        # Some types need to know about the dialect
        if isinstance(newobj, (MSText, MSNVarchar)):
            newobj.dialect = self
        return newobj

    def last_inserted_ids(self):
        return self.context.last_inserted_ids

    def get_default_schema_name(self, connection):
        return self.schema_name

    def set_default_schema_name(self, schema_name):
        self.schema_name = schema_name

    def last_inserted_ids(self):
        return self.context.last_inserted_ids

    def do_execute(self, cursor, statement, params, context=None, **kwargs):
        if params == {}:
            params = ()
        try:
            super(MSSQLDialect, self).do_execute(cursor, statement, params, context=context, **kwargs)
        finally:
            if context.IINSERT:
                cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table))

    def do_executemany(self, cursor, statement, params, context=None, **kwargs):
        try:
            super(MSSQLDialect, self).do_executemany(cursor, statement, params, context=context, **kwargs)
        finally:
            if context.IINSERT:
                cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table))

    def _execute(self, c, statement, parameters):
        try:
            if parameters == {}:
                parameters = ()
            c.execute(statement, parameters)
            self.context.rowcount = c.rowcount
            c.DBPROP_COMMITPRESERVE = "Y"
        except Exception, e:
            raise exc.DBAPIError.instance(statement, parameters, e)

    def table_names(self, connection, schema):
        from sqlalchemy.databases import information_schema as ischema
        return ischema.table_names(connection, schema)

    def raw_connection(self, connection):
        """Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes"""
        try:
            # TODO: probably want to move this to individual dialect subclasses to
            # save on the exception throw + simplify
            return connection.connection.__dict__['_pymssqlCnx__cnx']
        except:
            return connection.connection.adoConn

    def uppercase_table(self, t):
        # convert all names to uppercase -- fixes refs to INFORMATION_SCHEMA for case-senstive DBs, and won't matter for case-insensitive
        t.name = t.name.upper()
        if t.schema:
            t.schema = t.schema.upper()
        for c in t.columns:
            c.name = c.name.upper()
        return t

    def has_table(self, connection, tablename, schema=None):
        import sqlalchemy.databases.information_schema as ischema

        current_schema = schema or self.get_default_schema_name(connection)
        columns = self.uppercase_table(ischema.columns)
        s = sql.select([columns],
                   current_schema
                       and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema)
                       or columns.c.table_name==tablename,
                   )

        c = connection.execute(s)
        row  = c.fetchone()
        return row is not None

    def reflecttable(self, connection, table, include_columns):
        import sqlalchemy.databases.information_schema as ischema
        # Get base columns
        if table.schema is not None:
            current_schema = table.schema
        else:
            current_schema = self.get_default_schema_name(connection)

        columns = self.uppercase_table(ischema.columns)
        s = sql.select([columns],
                   current_schema
                       and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema)
                       or columns.c.table_name==table.name,
                   order_by=[columns.c.ordinal_position])

        c = connection.execute(s)
        found_table = False
        while True:
            row = c.fetchone()
            if row is None:
                break
            found_table = True
            (name, type, nullable, charlen, numericprec, numericscale, default) = (
                row[columns.c.column_name],
                row[columns.c.data_type],
                row[columns.c.is_nullable] == 'YES',
                row[columns.c.character_maximum_length],
                row[columns.c.numeric_precision],
                row[columns.c.numeric_scale],
                row[columns.c.column_default]
            )
            if include_columns and name not in include_columns:
                continue

            args = []
            for a in (charlen, numericprec, numericscale):
                if a is not None:
                    args.append(a)
            coltype = self.ischema_names.get(type, None)
            if coltype == MSText or (coltype == MSString and charlen == -1):
                coltype = MSText()
            else:
                if coltype is None:
                    util.warn("Did not recognize type '%s' of column '%s'" %
                              (type, name))
                    coltype = sqltypes.NULLTYPE

                elif coltype in (MSNVarchar, AdoMSNVarchar) and charlen == -1:
                    args[0] = None
                coltype = coltype(*args)
            colargs = []
            if default is not None:
                colargs.append(schema.DefaultClause(sql.text(default)))

            table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs))

        if not found_table:
            raise exc.NoSuchTableError(table.name)

        # We also run an sp_columns to check for identity columns:
        cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (table.name, current_schema))
        ic = None
        while True:
            row = cursor.fetchone()
            if row is None:
                break
            col_name, type_name = row[3], row[5]
            if type_name.endswith("identity"):
                ic = table.c[col_name]
                ic.autoincrement = True
                # setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute
                ic.sequence = schema.Sequence(ic.name + '_identity')
                # MSSQL: only one identity per table allowed
                cursor.close()
                break
        if not ic is None:
            try:
                cursor = connection.execute("select ident_seed(?), ident_incr(?)", table.fullname, table.fullname)
                row = cursor.fetchone()
                cursor.close()
                if not row is None:
                    ic.sequence.start = int(row[0])
                    ic.sequence.increment = int(row[1])
            except:
                # ignoring it, works just like before
                pass

        # Add constraints
        RR = self.uppercase_table(ischema.ref_constraints)    #information_schema.referential_constraints
        TC = self.uppercase_table(ischema.constraints)        #information_schema.table_constraints
        C  = self.uppercase_table(ischema.pg_key_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column
        R  = self.uppercase_table(ischema.pg_key_constraints).alias('R') #information_schema.constraint_column_usage: the referenced column

        # Primary key constraints
        s = sql.select([C.c.column_name, TC.c.constraint_type], sql.and_(TC.c.constraint_name == C.c.constraint_name,
                                                                         C.c.table_name == table.name,
                                                                         C.c.table_schema == (table.schema or current_schema)))
        c = connection.execute(s)
        for row in c:
            if 'PRIMARY' in row[TC.c.constraint_type.name]:
                table.primary_key.add(table.c[row[0]])

        # Foreign key constraints
        s = sql.select([C.c.column_name,
                        R.c.table_schema, R.c.table_name, R.c.column_name,
                        RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule],
                       sql.and_(C.c.table_name == table.name,
                                C.c.table_schema == (table.schema or current_schema),
                                C.c.constraint_name == RR.c.constraint_name,
                                R.c.constraint_name == RR.c.unique_constraint_name,
                                C.c.ordinal_position == R.c.ordinal_position
                                ),
                       order_by = [RR.c.constraint_name, R.c.ordinal_position])
        rows = connection.execute(s).fetchall()

        def _gen_fkref(table, rschema, rtbl, rcol):
            if rschema == current_schema and not table.schema:
                return '.'.join([rtbl, rcol])
            else:
                return '.'.join([rschema, rtbl, rcol])

        # group rows by constraint ID, to handle multi-column FKs
        fknm, scols, rcols = (None, [], [])
        for r in rows:
            scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
            # if the reflected schema is the default schema then don't set it because this will
            # play into the metadata key causing duplicates.
            if rschema == current_schema and not table.schema:
                schema.Table(rtbl, table.metadata, autoload=True, autoload_with=connection)
            else:
                schema.Table(rtbl, table.metadata, schema=rschema, autoload=True, autoload_with=connection)
            if rfknm != fknm:
                if fknm:
                    table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm))
                fknm, scols, rcols = (rfknm, [], [])
            if not scol in scols:
                scols.append(scol)
            if not (rschema, rtbl, rcol) in rcols:
                rcols.append((rschema, rtbl, rcol))

        if fknm and scols:
            table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm))


class MSSQLDialect_pymssql(MSSQLDialect):
    supports_sane_rowcount = False
    max_identifier_length = 30

    def import_dbapi(cls):
        import pymssql as module
        # pymmsql doesn't have a Binary method.  we use string
        # TODO: monkeypatching here is less than ideal
        module.Binary = lambda st: str(st)
        return module
    import_dbapi = classmethod(import_dbapi)

    ischema_names = MSSQLDialect.ischema_names.copy()


    def __init__(self, **params):
        super(MSSQLDialect_pymssql, self).__init__(**params)
        self.use_scope_identity = True

        # pymssql understands only ascii
        if self.convert_unicode:
            self.encoding = params.get('encoding', 'ascii')

    def do_rollback(self, connection):
        # pymssql throws an error on repeated rollbacks. Ignore it.
        # TODO: this is normal behavior for most DBs.  are we sure we want to ignore it ?
        try:
            connection.rollback()
        except:
            pass

    def create_connect_args(self, url):
        r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
        if hasattr(self, 'query_timeout'):
            self.dbapi._mssql.set_query_timeout(self.query_timeout)
        return r

    def make_connect_string(self, keys):
        if keys.get('port'):
            # pymssql expects port as host:port, not a separate arg
            keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
            del keys['port']
        return [[], keys]

    def is_disconnect(self, e):
        return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)


class MSSQLDialect_pyodbc(MSSQLDialect):
    supports_sane_rowcount = False
    supports_sane_multi_rowcount = False
    # PyODBC unicode is broken on UCS-4 builds
    supports_unicode = sys.maxunicode == 65535
    supports_unicode_statements = supports_unicode

    def __init__(self, **params):
        super(MSSQLDialect_pyodbc, self).__init__(**params)
        # whether use_scope_identity will work depends on the version of pyodbc
        try:
            import pyodbc
            self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset')
        except:
            pass

    def import_dbapi(cls):
        import pyodbc as module
        return module
    import_dbapi = classmethod(import_dbapi)

    colspecs = MSSQLDialect.colspecs.copy()
    if supports_unicode:
        colspecs[sqltypes.Unicode] = AdoMSNVarchar
    colspecs[sqltypes.Date] = MSDate_pyodbc
    colspecs[sqltypes.DateTime] = MSDateTime_pyodbc

    ischema_names = MSSQLDialect.ischema_names.copy()
    if supports_unicode:
        ischema_names['nvarchar'] = AdoMSNVarchar
    ischema_names['smalldatetime'] = MSDate_pyodbc
    ischema_names['datetime'] = MSDateTime_pyodbc

    def make_connect_string(self, keys):
        if 'max_identifier_length' in keys:
            self.max_identifier_length = int(keys.pop('max_identifier_length'))
        if 'dsn' in keys:
            connectors = ['dsn=%s' % keys.pop('dsn')]
        else:
            port = ''
            if 'port' in keys and (
                keys.get('driver', 'SQL Server') == 'SQL Server'):
                port = ',%d' % int(keys.pop('port'))

            connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'),
                          'Server=%s%s' % (keys.pop('host', ''), port),
                          'Database=%s' % keys.pop('database', '') ]

            if 'port' in keys and not port:
                connectors.append('Port=%d' % int(keys.pop('port')))

        user = keys.pop("user", None)
        if user:
            connectors.append("UID=%s" % user)
            connectors.append("PWD=%s" % keys.pop('password', ''))
        else:
            connectors.append("TrustedConnection=Yes")

        # if set to 'Yes', the ODBC layer will try to automagically convert 
        # textual data from your database encoding to your client encoding 
        # This should obviously be set to 'No' if you query a cp1253 encoded 
        # database from a latin1 client... 
        if 'odbc_autotranslate' in keys:
            connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))

        # Allow specification of partial ODBC connect string
        if 'odbc_options' in keys: 
            odbc_options=keys.pop('odbc_options')
            if odbc_options[0]=="'" and odbc_options[-1]=="'":
                odbc_options=odbc_options[1:-1]
            connectors.append(odbc_options)
        connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
        return [[";".join (connectors)], {}]

    def is_disconnect(self, e):
        if isinstance(e, self.dbapi.ProgrammingError):
            return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
        elif isinstance(e, self.dbapi.Error):
            return '[08S01]' in str(e)
        else:
            return False

    def create_execution_context(self, *args, **kwargs):
        return MSSQLExecutionContext_pyodbc(self, *args, **kwargs)

    def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
        super(MSSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
        if context and context.HASIDENT and (not context.IINSERT) and context.dialect.use_scope_identity:
            import pyodbc
            # Fetch the last inserted id from the manipulated statement
            # We may have to skip over a number of result sets with no data (due to triggers, etc.)
            while True:
                try:
                    row = cursor.fetchone()
                    break
                except pyodbc.Error, e:
                    cursor.nextset()
            context._last_inserted_ids = [int(row[0])]

    def _server_version_info(self, dbapi_con):
        """Convert a pyodbc SQL_DBMS_VER string into a tuple."""

        version = []
        r = re.compile('[.\-]')
        for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
            try:
                version.append(int(n))
            except ValueError:
                version.append(n)
        return tuple(version)

class MSSQLDialect_adodbapi(MSSQLDialect):
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = True
    supports_unicode = sys.maxunicode == 65535
    supports_unicode_statements = True

    def import_dbapi(cls):
        import adodbapi as module
        return module
    import_dbapi = classmethod(import_dbapi)

    colspecs = MSSQLDialect.colspecs.copy()
    colspecs[sqltypes.Unicode] = AdoMSNVarchar
    colspecs[sqltypes.DateTime] = MSDateTime_adodbapi

    ischema_names = MSSQLDialect.ischema_names.copy()
    ischema_names['nvarchar'] = AdoMSNVarchar
    ischema_names['datetime'] = MSDateTime_adodbapi

    def make_connect_string(self, keys):
        connectors = ["Provider=SQLOLEDB"]
        if 'port' in keys:
            connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
        else:
            connectors.append ("Data Source=%s" % keys.get("host"))
        connectors.append ("Initial Catalog=%s" % keys.get("database"))
        user = keys.get("user")
        if user:
            connectors.append("User Id=%s" % user)
            connectors.append("Password=%s" % keys.get("password", ""))
        else:
            connectors.append("Integrated Security=SSPI")
        return [[";".join (connectors)], {}]

    def is_disconnect(self, e):
        return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)

dialect_mapping = {
    'pymssql':  MSSQLDialect_pymssql,
    'pyodbc':   MSSQLDialect_pyodbc,
    'adodbapi': MSSQLDialect_adodbapi
    }


class MSSQLCompiler(compiler.DefaultCompiler):
    operators = compiler.OPERATORS.copy()
    operators.update({
        sql_operators.concat_op: '+',
        sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
    })

    functions = compiler.DefaultCompiler.functions.copy()
    functions.update (
        {
            sql_functions.now: 'CURRENT_TIMESTAMP',
            sql_functions.current_date: 'GETDATE()',
            'length': lambda x: "LEN(%s)" % x
        }
    )

    def __init__(self, *args, **kwargs):
        super(MSSQLCompiler, self).__init__(*args, **kwargs)
        self.tablealiases = {}

    def get_select_precolumns(self, select):
        """ MS-SQL puts TOP, it's version of LIMIT here """
        if select._distinct or select._limit:
            s = select._distinct and "DISTINCT " or ""
            
            if select._limit:
                if not select._offset:
                    s += "TOP %s " % (select._limit,)
                else:
                    if not self.dialect.has_window_funcs:
                        raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
            return s
        return compiler.DefaultCompiler.get_select_precolumns(self, select)

    def limit_clause(self, select):
        # Limit in mssql is after the select keyword
        return ""

    def visit_select(self, select, **kwargs):
        """Look for ``LIMIT`` and OFFSET in a select statement, and if
        so tries to wrap it in a subquery with ``row_number()`` criterion.

        """
        if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and select._offset:
            # to use ROW_NUMBER(), an ORDER BY is required.
            orderby = self.process(select._order_by_clause)
            if not orderby:
                raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')

            _offset = select._offset
            _limit = select._limit
            select._mssql_visit = True
            select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()

            limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
            limitselect.append_whereclause("mssql_rn>%d" % _offset)
            if _limit is not None:
                limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
            return self.process(limitselect, iswrapper=True, **kwargs)
        else:
            return compiler.DefaultCompiler.visit_select(self, select, **kwargs)

    def _schema_aliased_table(self, table):
        if getattr(table, 'schema', None) is not None:
            if table not in self.tablealiases:
                self.tablealiases[table] = table.alias()
            return self.tablealiases[table]
        else:
            return None

    def visit_table(self, table, mssql_aliased=False, **kwargs):
        if mssql_aliased:
            return super(MSSQLCompiler, self).visit_table(table, **kwargs)

        # alias schema-qualified tables
        alias = self._schema_aliased_table(table)
        if alias is not None:
            return self.process(alias, mssql_aliased=True, **kwargs)
        else:
            return super(MSSQLCompiler, self).visit_table(table, **kwargs)

    def visit_alias(self, alias, **kwargs):
        # translate for schema-qualified table aliases
        self.tablealiases[alias.original] = alias
        kwargs['mssql_aliased'] = True
        return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)

    def visit_column(self, column, result_map=None, **kwargs):
        if column.table is not None and \
            (not self.isupdate and not self.isdelete) or self.is_subquery():
            # translate for schema-qualified table aliases
            t = self._schema_aliased_table(column.table)
            if t is not None:
                converted = expression._corresponding_column_or_error(t, column)

                if result_map is not None:
                    result_map[column.name.lower()] = (column.name, (column, ), column.type)

                return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs)

        return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs)

    def visit_binary(self, binary, **kwargs):
        """Move bind parameters to the right-hand side of an operator, where
        possible.

        """
        if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
            and not isinstance(binary.right, expression._BindParamClause):
            return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
        else:
            if (binary.operator in (operator.eq, operator.ne)) and (
                (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \
                (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \
                 isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)):
                op = binary.operator == operator.eq and "IN" or "NOT IN"
                return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
            return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)

    def label_select_column(self, select, column, asfrom):
        if isinstance(column, expression._Function):
            return column.label(None)
        else:
            return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)

    def for_update_clause(self, select):
        # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
        return ''

    def order_by_clause(self, select):
        order_by = self.process(select._order_by_clause)

        # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
        if order_by and (not self.is_subquery() or select._limit):
            return " ORDER BY " + order_by
        else:
            return ""


class MSSQLSchemaGenerator(compiler.SchemaGenerator):
    def get_column_specification(self, column, **kwargs):
        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()

        # install a IDENTITY Sequence if we have an implicit IDENTITY column
        if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
                column.autoincrement and isinstance(column.type, sqltypes.Integer) and not column.foreign_keys:
            if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
                column.sequence = schema.Sequence(column.name + '_seq')

        if not column.nullable:
            colspec += " NOT NULL"

        if hasattr(column, 'sequence'):
            column.table.has_sequence = column
            colspec += " IDENTITY(%s,%s)" % (column.sequence.start or 1, column.sequence.increment or 1)
        else:
            default = self.get_column_default_string(column)
            if default is not None:
                colspec += " DEFAULT " + default

        return colspec

class MSSQLSchemaDropper(compiler.SchemaDropper):
    def visit_index(self, index):
        self.append("\nDROP INDEX %s.%s" % (
            self.preparer.quote_identifier(index.table.name),
            self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
            ))
        self.execute()


class MSSQLDefaultRunner(base.DefaultRunner):
    # TODO: does ms-sql have standalone sequences ?
    # A: No, only auto-incrementing IDENTITY property of a column
    pass

class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
    reserved_words = compiler.IdentifierPreparer.reserved_words.union(MSSQL_RESERVED_WORDS)

    def __init__(self, dialect):
        super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')

    def _escape_identifier(self, value):
        #TODO: determin MSSQL's escapeing rules
        return value

dialect = MSSQLDialect
dialect.statement_compiler = MSSQLCompiler
dialect.schemagenerator = MSSQLSchemaGenerator
dialect.schemadropper = MSSQLSchemaDropper
dialect.preparer = MSSQLIdentifierPreparer
dialect.defaultrunner = MSSQLDefaultRunner
